/*Copyright 2009,2010 Alex Graves

This file is part of RNNLIB.

RNNLIB is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

RNNLIB is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with RNNLIB.  If not, see <http://www.gnu.org/licenses/>.*/

#ifndef _INCLUDED_StringAlignment_h
#define _INCLUDED_StringAlignment_h

#include <vector>
#include <map>
#include <iostream>
#include "Helpers.hpp"

using namespace std;

/**
 * Levenshtein alignment between a reference sequence and a test sequence.
 *
 * The constructor fills an (n+1) x (m+1) dynamic-programming matrix of
 * unit-cost edit distances and, optionally, walks it back from (n, m) to
 * split that distance into substitutions, deletions and insertions.
 *
 * R1 / R2 must provide size() and operator[], and a boost::range_value
 * specialisation; elements of R1 and R2 are compared with operator==.
 */
template<class R1, class R2> struct StringAlignment {
  //data
  // Per-symbol error tallies; only filled when both trackErrors and backtrace
  // are true. subsMap[ref][test] counts ref->test substitutions, delsMap[ref]
  // counts deleted reference symbols, insMap[test] counts inserted test symbols.
  // NOTE(review): the inner key of subsMap and the key of insMap store R2
  // elements but are typed on R1's value type; exact only when they match.
  map<typename boost::range_value<R1>::type, map<typename boost::range_value<R1>::type, int> >
  subsMap;
  map<typename boost::range_value<R1>::type, int> delsMap;
  map<typename boost::range_value<R1>::type, int> insMap;
  // matrix[i][j] is the unit-cost edit distance between the first i reference
  // symbols and the first j test symbols. Left empty if either input is empty.
  Vector<Vector<int> > matrix;
  // Error counts along the backtraced path. They stay 0 when backtrace is false
  // (except for empty inputs, where they are known without a matrix).
  int substitutions;
  int deletions;
  int insertions;
  // Penalty-weighted distance when backtrace is true; plain unit-cost
  // Levenshtein distance otherwise.
  int distance;
  int subPenalty;
  int delPenalty;
  int insPenalty;
  size_t n;  // length of the reference sequence
  size_t m;  // length of the test sequence

  //functions
  /**
   * Align test_sequence against reference_sequence.
   *
   * @param reference_sequence sequence treated as ground truth (rows).
   * @param test_sequence      sequence being scored (columns).
   * @param trackErrors        also record which symbols were substituted,
   *                           deleted or inserted; has an effect only when
   *                           backtrace is true.
   * @param backtrace          recover the per-type error counts and weight
   *                           them by the penalties. When false, distance is
   *                           the unweighted Levenshtein distance.
   * @param sp, dp, ip         substitution / deletion / insertion penalties.
   *                           They only weight the counts of the unit-cost
   *                           optimal path; they do not change which path is
   *                           chosen.
   */
  StringAlignment(const R1 &reference_sequence, const R2 &test_sequence,
                  bool trackErrors = false, bool backtrace = true, int sp = 1, int dp = 1,
                  int ip = 1):
    substitutions(0),
    deletions(0),
    insertions(0),
    distance(0),
    subPenalty(sp),
    delPenalty(dp),
    insPenalty(ip),
    n(reference_sequence.size()),
    m(test_sequence.size()) {
    if (n == 0 || m == 0) {
      // Degenerate case: every reference symbol is deleted and every test
      // symbol is inserted (at most one of the two counts is non-zero).
      deletions = static_cast<int>(n);
      insertions = static_cast<int>(m);
      if (backtrace) {
        if (trackErrors) {
          for (size_t i = 0; i < n; ++i) {
            ++delsMap[reference_sequence[i]];
          }
          for (size_t j = 0; j < m; ++j) {
            ++insMap[test_sequence[j]];
          }
        }
        distance = (delPenalty * deletions) + (insPenalty * insertions);
      } else {
        distance = deletions + insertions;
      }
    } else {
      // Initialise the matrix: row 0 / column 0 hold the cost of inserting /
      // deleting whole prefixes.
      matrix.resize(n + 1);
      LOOP(Vector<int> &v, matrix) {
        v.resize(m + 1);
        fill(v, 0);
      }
      LOOP(int i, span(n + 1)) {
        matrix[i][0] = i;
      }
      LOOP(int j, span(m + 1)) {
        matrix[0][j] = j;
      }

      // Fill the remaining cells with unit costs.
      LOOP(int i, span(1, n + 1)) {
        const typename boost::range_value<R1>::type &s_i = reference_sequence[i - 1];
        LOOP(int j, span(1, m + 1)) {
          const typename boost::range_value<R2>::type &t_j = test_sequence[j - 1];
          const int cost = ((s_i == t_j) ? 0 : 1);
          const int above = matrix[i - 1][j];
          const int left = matrix[i][j - 1];
          const int diag = matrix[i - 1][j - 1];
          matrix[i][j] = min(above + 1,             // deletion
                             min(left + 1,          // insertion
                                 diag + cost));     // match / substitution
        }
      }

      //N.B sub,ins and del penalties are all set to 1 if backtrace is ignored
      if (backtrace) {
        // Walk back from (n, m). Cell (i, j) refers to the symbols
        // reference_sequence[i - 1] and test_sequence[j - 1].
        size_t i = n;
        size_t j = m;
        while (i != 0 && j != 0) {
          const int cell = matrix[i][j];
          if (reference_sequence[i - 1] == test_sequence[j - 1]) {
            // Match: equal symbols guarantee cell == matrix[i - 1][j - 1].
            // A value-only comparison would also accept an unequal pair whose
            // cell was reached via a deletion or an insertion.
            --i;
            --j;
          } else if (cell == matrix[i - 1][j - 1] + 1) {
            if (trackErrors) {
              ++subsMap[reference_sequence[i - 1]][test_sequence[j - 1]];
            }
            ++substitutions;
            --i;
            --j;
          } else if (cell == matrix[i - 1][j] + 1) {
            if (trackErrors) {
              ++delsMap[reference_sequence[i - 1]];
            }
            ++deletions;
            --i;
          } else {
            // Only remaining optimal predecessor: cell == matrix[i][j - 1] + 1.
            if (trackErrors) {
              ++insMap[test_sequence[j - 1]];
            }
            ++insertions;
            --j;
          }
        }

        // Any leftover reference prefix is deleted...
        for (; i != 0; --i) {
          if (trackErrors) {
            ++delsMap[reference_sequence[i - 1]];
          }
          ++deletions;
        }

        // ...and any leftover test prefix is inserted.
        for (; j != 0; --j) {
          if (trackErrors) {
            ++insMap[test_sequence[j - 1]];
          }
          ++insertions;
        }

        // Sanity check:
        check((substitutions + deletions + insertions) == matrix[n][m],
              "Found path with distance " + str(substitutions + deletions + insertions) +
              " but Levenshtein distance is " + str(matrix[n][m]));

        //scale individual errors by penalties
        distance = (subPenalty * substitutions) + (delPenalty * deletions) +
                   (insPenalty * insertions);
      } else {
        distance = matrix[n][m];
      }
    }
  }
};

#endif
